
#### HOUSEKEEPING ####

# Start from an empty workspace so nothing left over from an earlier session
# can leak into this run.
rm(list = ls())

## PACKAGES
# install_github("cran/LowRankQP")
# library() stops with an error when a package is missing; require() would only
# return FALSE and let the script fail later at the first fread() / %>% call.
package_list <- c("data.table", "tidyverse")
invisible(lapply(package_list, library, character.only = TRUE))

## PARAMS
n_job  <- 250
date   <- "final"
# NOTE(review): this path is relative to the working directory and does not go
# through res_dir (presumably defined by R/00_paths.R, sourced below) --
# confirm that both point at the same results folder.
lambda <- fread(file = paste0("results/pensynth/", date, "/inputs/lambda.csv"))$lambda

## DIRECTORIES
source("R/00_paths.R")
out_dir <- paste0(res_dir, "pensynth/", date, "/")
l_print <- format(lambda, scientific = TRUE)
# recursive = TRUE also creates "fixed/" (or any other missing parent), which a
# plain dir.create() would refuse to do.
if (!dir.exists(paste0(out_dir, "fixed/", l_print, "/"))) {
  dir.create(paste0(out_dir, "fixed/", l_print, "/"), recursive = TRUE)
}


#### FUNCTIONS ####

## PENSYNTH FUNCTIONS
# Helper scripts are looked up in "<ps_dir>/functions/<name>.R".
fun_list <- c("wsoll1", "regsynth", "regsynthpath", "pensynth_parallel", "TZero", "estimator_matching", "get_stats")
fun_list <- paste0(ps_dir, "/functions/", fun_list, ".R")
invisible(lapply(fun_list, source))

#### ~ GROUP SETUP FUNCTION ####
# Build the matching inputs for one group of customers.
#
# Arguments
#   df        data.table with a `group` column, a `customer_key` column, the
#             covariates named in X_vars, the outcome and the treatment flag.
#   group     optional; when supplied, only rows with df$group == group are
#             kept. When omitted, df is used as passed in.
#   X_vars    character vector of covariate column names.
#   y_var     name of the outcome column.
#   treat_var name of the treatment indicator column (1 = treated, 0 = control).
#
# Returns a list with
#   d, y               treatment indicator and outcome vectors
#   X                  covariate matrix, each column divided by its standard
#                      deviation among treated rows
#   X_unscaled         the covariates before rescaling (data.frame)
#   X0, X1             transposed covariate matrices (one column per unit) for
#                      control and treated rows
#   Y0, Y1             outcomes of control and treated rows
#   V                  identity matrix with one row/column per covariate
#   X0_unique          distinct control covariate profiles (transposed)
#   Y0_average         mean control outcome within each distinct profile
#   X0_ids, X1_ids     customer_key of control and treated rows
format_group <- function(df, group, X_vars, y_var, treat_var) {

  # Inside a data.table's `i`, a bare `group` resolves to the *column* named
  # group, so `df[df$group == group, ]` compared the column with itself and
  # never filtered. Evaluate the comparison outside `[` and index with a
  # single symbol, which data.table looks up in the calling frame.
  if (!missing(group)) {
    rows_in_group <- which(df[["group"]] == group)
    df <- df[rows_in_group, ]
  }

  d          <- df[[treat_var]]
  y          <- df[[y_var]]
  X          <- data.frame(df[, ..X_vars], check.names = FALSE)
  X_unscaled <- X

  # Rescale every covariate by its standard deviation among treated units.
  for (v in X_vars) {
    X[[v]] <- X[[v]] / sd(X[[v]][d == 1])
  }

  X  <- as.matrix(X)
  # drop = FALSE keeps the matrix shape when only one row (or one covariate)
  # is selected, so X0 / X1 always have covariates in rows and units in columns.
  X0 <- t(X[d == 0, , drop = FALSE])
  X1 <- t(X[d == 1, , drop = FALSE])
  Y0 <- y[d == 0]
  Y1 <- y[d == 1]
  V  <- diag(ncol(X))

  # Collapse controls that share an identical covariate profile and average
  # their outcomes.
  X0_unique  <- as.data.table(cbind(Y0, t(X0)))
  X0_unique  <- X0_unique[, list(Y0_average = mean(Y0)), X_vars]
  Y0_average <- as.vector(X0_unique[, Y0_average])
  X0_unique  <- t(as.matrix(X0_unique[, ..X_vars]))

  X0_ids <- df$customer_key[d == 0]
  X1_ids <- df$customer_key[d == 1]

  list(d = d,
       y = y,
       X = X,
       X_unscaled = X_unscaled,
       X0 = X0,
       X1 = X1,
       Y0 = Y0,
       Y1 = Y1,
       V = V,
       X0_unique = X0_unique,
       Y0_average = Y0_average,
       X0_ids = X0_ids,
       X1_ids = X1_ids)

}

#### SET UP ####

# Array task index supplied by SLURM; falls back to 100 when the variable is
# not set (as.integer("") is NA).
task_id <- as.integer(Sys.getenv("SLURM_ARRAY_TASK_ID"))
if (is.na(task_id)) {
  task_id <- 100L
}

## LOAD DATA

# Task 0 reads the original sample, every other task its own resampled file.
if (task_id == 0) {
  df <- fread(file = paste0(out_dir, "inputs/data.csv"))
} else {
  df <- fread(file = paste0(out_dir, "inputs/resamp/data_", task_id, ".csv"))
}
X_vars <- fread(file = paste0(out_dir, "inputs/X_vars.csv"))$X_vars

groups <- sort(unique(df$group))
setup  <- list()

# One (Group, Index) row per treated unit, collected per group and bound once
# at the end instead of growing a data frame inside the loop.
grid_parts <- vector("list", length(groups))

for (i in seq_along(groups)) {
  g <- groups[i]
  setup[[g]] <- format_group(df = df[df$group == g, ],
                             group = g,
                             X_vars = X_vars,
                             y_var = "mean_gb_post",
                             treat_var = "treat")
  n_treated <- length(setup[[g]]$Y1)
  # rep()/seq_len() give zero rows for a group without treated units, where
  # 1:length(...) would have produced the indices 1 and 0.
  grid_parts[[i]] <- data.frame(
    Group = rep(g, n_treated),
    Index = seq_len(n_treated)
  )
}

grid_df <- do.call(
  rbind,
  c(list(data.frame(Group = as.numeric(), Index = as.numeric())), grid_parts)
)
rm(grid_parts, n_treated, i)

## PRE-PERIOD TREATED CHARACTERISTICS
# Outcome file restricted to the customers present in df. Only customer_key is
# taken from df, so the join adds no columns and matches on that key.
sp_df <- inner_join(
  fread(paste0(dat_dir, "ubp_outcomes.csv")),
  dplyr::select(df, customer_key)
)

## CALCULATE BILL
# Inputs read as globals by calc_bill(); all are NA placeholders in this file.
#   tier_allow  allowance, indexed by tier
#   p_tu, q_tu  overage charge is p_tu * ceiling(excess / q_tu)
#   p_i         bill_i amount, indexed by tier
#   p_t         multiplied by the vid flag to give bill_t
#   days_in_m   multiplies gb before it is compared with the allowance
#   max_ovr     upper bound applied to the overage charge
tier_allow <- p_tu <- q_tu <- p_i <- p_t <- days_in_m <- max_ovr <- NA

# First / last month index for counterfactual calculations
month_start <- month_end <- NA
# Number of service tiers
n_tier <- NA

# Compute bill components for a vector of customers.
#
#   gb    usage; multiplied by days_in_m before comparison with the allowance
#   tier  index into tier_allow and p_i
#   vid   multiplier applied to p_t
#   pref  prefix prepended to every output column name
#   suf   suffix appended to every output column name
#
# Reads tier_allow, days_in_m, p_tu, q_tu, max_ovr, p_i and p_t from the
# enclosing environment. Returns a data.frame with columns
# <pref>bill_i<suf>, <pref>bill_t<suf> and <pref>ovr<suf>.
calc_bill <- function(gb, tier, vid, pref = "", suf = "") {
  allowance <- tier_allow[tier]
  usage     <- days_in_m * gb

  # Usage above the allowance is charged p_tu per started block of q_tu.
  charge <- ifelse(usage <= allowance,
                   0,
                   p_tu * ceiling((usage - allowance) / q_tu))
  # The overage charge never exceeds max_ovr.
  charge <- ifelse(charge > max_ovr, max_ovr, charge)

  bills <- data.frame(bill_i = p_i[tier],
                      bill_t = p_t * vid,
                      ovr    = charge)
  names(bills) <- paste0(pref, names(bills), suf)

  bills
}

# Bill components from the en_* columns (prefix "en_") and, with zero usage,
# from the st_* tier / vid columns (prefix "st_").
sp_df <- cbind(
  sp_df,
  calc_bill(gb = sp_df$en_tot_gb, tier = sp_df$en_tier, vid = sp_df$en_vid, pref = "en_")
)
sp_df <- cbind(
  sp_df,
  calc_bill(gb = 0, tier = sp_df$st_tier, vid = sp_df$st_vid, pref = "st_")
)

# Calculate bills for counterfactual months: one set of en_*_<m> columns per
# month index m.
for (m in month_start:month_end) {
  sp_df <- cbind(
    sp_df,
    calc_bill(gb   = sp_df[[paste0("tot_gb_", m)]],
              tier = sp_df[[paste0("en_tier_", m)]],
              vid  = sp_df[[paste0("en_vid_", m)]],
              pref = "en_",
              suf  = paste0("_", m))
  )
}

#### RESULTS ####

# Number of counterfactual usage draws per treated unit (used after the loop).
cf_samp_n <- 5000

# w_all is block diagonal: one block per group, treated units in rows and
# control units in columns. st_t / st_c track where the next block starts.
st_t  <- 1
st_c  <- 1
w_all <- matrix(data = 0, nrow = sum(df$treat), ncol = sum(1 - df$treat))

# The weight files live in the same directory for every group, so list and
# parse them once.
if (task_id == 0) {
  w_dir <- paste0(out_dir, "fixed/", l_print)
} else {
  w_dir <- paste0(out_dir, "fixed/", l_print, "_", "SE", "/",
                  sprintf("%03d", task_id))
}
file_list <- list.files(w_dir, pattern = "\\.csv$", full.names = TRUE)

# Group and index are parsed from the file name only. Matching against the
# full path let directory names such as "1e-04_SE" satisfy the "<g>_" pattern
# and let underscores elsewhere in the path shift the extracted index.
# NOTE(review): assumes file stems end in "<group>_<index>" (for example
# "3_12.csv") -- confirm against the script that writes these files.
file_tokens <- strsplit(sub("\\.csv$", "", basename(file_list)), "_", fixed = TRUE)
file_group  <- vapply(
  file_tokens,
  function(tk) if (length(tk) >= 2L) tk[length(tk) - 1L] else NA_character_,
  character(1)
)
file_index  <- suppressWarnings(as.numeric(vapply(
  file_tokens,
  function(tk) if (length(tk) >= 1L) tk[length(tk)] else NA_character_,
  character(1)
)))

for (g in groups) {

  print(g)

  nt_g <- sum(setup[[g]]$d)
  nc_g <- sum(1 - setup[[g]]$d)

  # Exact comparison of the group token, so group 1 no longer picks up the
  # files of groups 11, 21, ...
  if (is.numeric(g)) {
    is_g <- suppressWarnings(as.numeric(file_group)) == g
  } else {
    is_g <- file_group == as.character(g)
  }
  sel <- which(is_g & !is.na(file_index))
  sel <- sel[order(file_index[sel])]

  if (length(sel) != nt_g) {
    stop("Group ", g, ": found ", length(sel), " weight file(s) in '", w_dir,
         "' but the group has ", nt_g, " treated unit(s).", call. = FALSE)
  }

  # Each file is read as one column of weights over the group's controls;
  # stacking them as rows (in index order) gives a treated-by-control block.
  w_rows <- lapply(file_list[sel], function(f) as.numeric(fread(f)[[1]]))
  if (any(lengths(w_rows) != nc_g)) {
    stop("Group ", g, ": a weight file does not contain exactly ", nc_g,
         " control weight(s).", call. = FALSE)
  }
  w_g <- do.call(rbind, w_rows)

  en_t <- st_t + nt_g - 1
  en_c <- st_c + nc_g - 1
  w_all[st_t:en_t, st_c:en_c] <- w_g
  st_t <- en_t + 1
  st_c <- en_c + 1

}


# Per-treated-unit summaries of the weight matrix.
nw      <- ncol(w_all)                  # number of control units
nwpos   <- rowSums(w_all > 0)           # controls with positive weight
wposmax <- apply(w_all, 1, max)         # largest weight in each row
wmaxid  <- apply(w_all, 1, which.max)   # column of that largest weight
nt      <- nrow(w_all)                  # number of treated units

## EXPECTED OVERAGES

# Allowance at the treated unit's own tier and one tier up (capped at n_tier).
# NOTE(review): rows of w_all were filled group by group, while these vectors
# follow the row order of df -- this lines up only if df is sorted by group;
# confirm.
tier   <- df[df$treat == 1, ]$svc_tier
tier2  <- pmin(tier + 1, n_tier)
allow  <- tier_allow[tier]
allow2 <- tier_allow[tier2]

cf_usage <- df[df$treat == 0, ]$mean_gb_post
cf_samp  <- matrix(data = 0, nrow = cf_samp_n, ncol = nt)

if (any(nwpos == 0)) {
  stop("Treated unit(s) without any positive weight: ",
       paste(which(nwpos == 0), collapse = ", "), call. = FALSE)
}

# Draw control usage with probability equal to the synthetic-control weights.
# With a single positive weight the draw is degenerate, so that control's
# usage is used directly.
for (it in seq_len(nt)) {
  if (nwpos[it] == 1) {
    cf_samp[, it] <- cf_usage[wmaxid[it]]
  } else {
    cf_samp[, it] <- sample(cf_usage,
                            size = cf_samp_n,
                            replace = TRUE,
                            prob = w_all[it, ])
  }
}

# Overage charge of every draw under the current and the upgraded allowance.
# NOTE(review): calc_bill() compares days_in_m * gb with the allowance and
# caps the charge at max_ovr; here usage is compared directly and no cap is
# applied -- confirm that this difference is intended.
cf_ovr  <- matrix(data = 0, nrow = cf_samp_n, ncol = nt)
cf_ovr2 <- matrix(data = 0, nrow = cf_samp_n, ncol = nt)
for (it in seq_len(nt)) {
  cf_ovr[, it]  <- ifelse(cf_samp[, it] <= allow[it], 0,
                          p_tu * ceiling((cf_samp[, it] - allow[it]) / q_tu))
  cf_ovr2[, it] <- ifelse(cf_samp[, it] <= allow2[it], 0,
                          p_tu * ceiling((cf_samp[, it] - allow2[it]) / q_tu))
}

# Average over draws: expected overage per treated unit.
e_ovr  <- apply(cf_ovr, 2, mean)
e_ovr2 <- apply(cf_ovr2, 2, mean)

grid_df$e_ovr     <- e_ovr
grid_df$e_ovr_upg <- e_ovr2

## COUNTERFACTUALS
# Each counterfactual is the weighted average of a control-unit column, with
# one row of w_all per treated unit.
# NOTE(review): w_all %*% x assumes the rows of csp_df are in the same order
# as the columns of w_all (controls, filled group by group), and the final
# cbind assumes tsp_df follows the row order of grid_df. Neither order is
# enforced here -- confirm that sp_df and df are sorted consistently.
c_ids  <- df$customer_key[df$treat == 0]
t_ids  <- df$customer_key[df$treat == 1]
csp_df <- sp_df %>% dplyr::filter(customer_key %in% c_ids)

stopifnot(nrow(csp_df) == ncol(w_all))

# Output column name -> source column in csp_df.
cf_sources <- c(
  cf_upg_tier     = "upg_tier",
  cf_dng_tier     = "dng_tier",
  cf_add_vid      = "add_vid",
  cf_drop_vid     = "drop_vid",
  cf_del_gb       = "del_gb",
  cf_del_video    = "del_video",
  cf_del_browsing = "del_browsing",
  cf_del_other    = "del_other",
  cf_del_netflix  = "del_netflix",
  cf_del_youtube  = "del_youtube",
  cf_del_hulu     = "del_hulu",
  cf_del_slingtv  = "del_slingtv",
  cf_st_gb        = "st_tot_gb",
  cf_st_video     = "st_gb_video",
  cf_st_browsing  = "st_gb_browsing",
  cf_st_other     = "st_gb_other",
  cf_st_netflix   = "st_gb_netflix",
  cf_st_youtube   = "st_gb_youtube",
  cf_st_hulu      = "st_gb_hulu",
  cf_st_slingtv   = "st_gb_slingtv",
  cf_st_bill_i    = "st_bill_i",
  cf_st_bill_t    = "st_bill_t",
  cf_st_ovr       = "st_ovr",
  cf_en_gb        = "en_tot_gb",
  cf_en_video     = "en_gb_video",
  cf_en_browsing  = "en_gb_browsing",
  cf_en_other     = "en_gb_other",
  cf_en_netflix   = "en_gb_netflix",
  cf_en_youtube   = "en_gb_youtube",
  cf_en_hulu      = "en_gb_hulu",
  cf_en_slingtv   = "en_gb_slingtv",
  cf_en_bill_i    = "en_bill_i",
  cf_en_bill_t    = "en_bill_t",
  cf_en_ovr       = "en_ovr"
)

missing_cols <- setdiff(cf_sources, names(csp_df))
if (length(missing_cols) > 0) {
  stop("Columns missing from the outcome data: ",
       paste(missing_cols, collapse = ", "), call. = FALSE)
}

cf_df <- as.data.frame(
  lapply(cf_sources, function(src) as.numeric(w_all %*% csp_df[[src]]))
)

# Add month-specific counterfactual variables
for (m in month_start:month_end) {
  for (stem in c("en_bill_i_", "en_bill_t_", "en_ovr_")) {
    cf_df[[paste0("cf_", stem, m)]] <- as.numeric(w_all %*% csp_df[[paste0(stem, m)]])
  }
}

# Treated units' own outcomes, appended next to their counterfactuals.
tsp_df <- sp_df %>% dplyr::filter(customer_key %in% t_ids)
stopifnot(nrow(tsp_df) == nrow(grid_df))
grid_df <- grid_df %>%
  cbind(cf_df) %>%
  cbind(tsp_df %>% dplyr::select(-customer_key))

# The "<lambda>_SE" directory is not created anywhere else in this script, so
# make sure it exists before writing (no-op when it is already there).
cf_file <- paste0(out_dir, "fixed/", l_print, "_", "SE", "/", sprintf("%03d", task_id), "-cf.csv")
dir.create(dirname(cf_file), recursive = TRUE, showWarnings = FALSE)
fwrite(x = grid_df, file = cf_file)

